# Spectral GCN + Attention Recovery + LSTM
# This code trains and tests the GNN model for the COVID-19 infection prediction in Tokyo
# Author: Jiawei Xue, August 26, 2021
# Step 1: read and pack the training and testing data
# Step 2: training epoch, training process, testing
# Step 3: build the model = spectral GCN + Attention Recovery + LSTM
# Step 4: main function
# Step 5: evaluation
# Step 6: visualization
import os
import csv
import json
import copy
import time
import random
import string
import argparse
import numpy as np
import pandas as pd
import geopandas as gpd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import matplotlib.pyplot as plt
from matplotlib import pyplot as plt
import torch.nn.functional as F
from Memory_spectral_T3_GCN_memory_light_1_layer import SpecGCN
from Memory_spectral_T3_GCN_memory_light_1_layer import SpecGCN_LSTM
#torch.set_printoptions(precision=8)
#hyperparameter for the setting
# X_day: number of past days used as model input; Y_day: number of future days predicted.
X_day, Y_day = 21,14
#START_DATE, END_DATE = '20200414','20210207'
#START_DATE, END_DATE = '20200808','20210603'
# Study period (yyyymmdd) used to cut the sliding-window samples.
START_DATE, END_DATE = '20200720','20210515'
# Moving-average window (days) used by the smoothing functions below.
WINDOW_SIZE = 7
#hyperparameter for the learning
DROPOUT, ALPHA = 0.50, 0.20
NUM_EPOCHS, BATCH_SIZE, LEARNING_RATE = 100, 8, 0.0001
# Layer widths for SpecGCN_LSTM (see model_train).
HIDDEN_DIM_1, OUT_DIM_1, HIDDEN_DIM_2 = 6,4,3
# Infection counts are divided by this ratio; printed losses multiply it back (squared, MSE).
infection_normalize_ratio = 100.0
# Web-search counts are multiplied by this ratio in read_text_data.
web_search_normalize_ratio = 100.0
train_ratio = 0.7
validate_ratio = 0.1
# Fix the random seeds for reproducibility.
SEED = 5
torch.manual_seed(SEED)
# r is passed as random.shuffle(..., random=r) in train_process.
# NOTE(review): the `random` parameter of random.shuffle was removed in
# Python 3.11 -- confirm the runtime version before upgrading.
r = random.random
random.seed(5)
#1.total period (mobility+text):
#from 20200201 to 20210620: (29+31+30+31+30+31+31+30+31+30+31)+(31+28+31+30+31+20)\
#= 335 + 171 = 506;
#2.number of zones: 23;
#3.infection period:
#20200331 to 20210620: (1+30+31+30+31+31+30+31+30+31)+(31+28+31+30+31+20) = 276 + 171 = 447.
#1. Mobility: functions 1.2 to 1.7
#2. Text: functions 1.8 to 1.14
#3. Infection: functions 1.15
#4. Preprocess: functions 1.16 to 1.24
#5. Learn: functions 1.25 to 1.26
#function 1.1
#get the central areas of Tokyo (e.g., the Special wards of Tokyo)
#return: a 23 zone shapefile
def read_tokyo_23():
    """Load the shapefile of Tokyo's 23 special wards as a GeoDataFrame."""
    shp_dir = "/data/HSEES/xue/xue_codes/disease_prediction_ml/gml_code/present_model_version10/tokyo_23"
    shp_name = "tokyo_23zones.shp"
    return gpd.read_file(os.path.join(shp_dir, shp_name))
##################1.Mobility#####################
#function 1.2
#get the average of two days' mobility (infection) records
def mob_inf_average(data, key1, key2):
    """Average two days' mobility (or infection) records.

    Only ids present in both days' records are kept.
    """
    first, second = data[key1], data[key2]
    return {zone: (first[zone] + second[zone]) / 2.0
            for zone in first if zone in second}
#function 1.3
#get the average of multiple days' mobility (infection) records
def mob_inf_average_multiple(data, keyList):
    """Average multiple days' mobility (or infection) records.

    A zone missing on some of the days is still divided by len(keyList),
    matching the original accumulation semantics.
    """
    new_record = dict()
    num_day = len(keyList)
    for key in keyList:
        for zone_id, value in data[key].items():
            # plain dict membership instead of the O(n) `in list(keys())` scan
            new_record[zone_id] = new_record.get(zone_id, 0) + value
    for zone_id in new_record:
        new_record[zone_id] = new_record[zone_id] * 1.0 / num_day
    return new_record
#function 1.4
#generate the dateList: [20200101, 20200102, ..., 20211231]
def generate_dateList():
    """Return every calendar date of 2020 and 2021 as 'yyyymmdd' strings,
    in chronological order (2020 is a leap year)."""
    days_per_month = {"2020": [31,29,31,30,31,30,31,31,30,31,30,31],
                      "2021": [31,28,31,30,31,30,31,31,30,31,30,31]}
    dates = list()
    for year in ("2020", "2021"):
        for month_idx, n_days in enumerate(days_per_month[year]):
            for day in range(1, n_days + 1):
                dates.append("%s%02d%02d" % (year, month_idx + 1, day))
    return dates
#function 1.5
#smooth the mobility (infection) data using the neighborhood average
#under a given window size
#dateList: [20200101, 20200102, ..., 20211231]
def mob_inf_smooth(data, window_size, dateList):
    """Smooth mobility (infection) records in place with a centered moving
    average of width `window_size` over calendar days.

    dateList: the full ordered date list, e.g. [20200101, ..., 20211231].
    Averages are always computed from the original (pre-smoothing) records
    held by the shallow copy, then written back into `data`.
    """
    data_copy = copy.copy(data)  # shallow: inner day records are shared; only data[key] is rebound
    data_key_list = list(data_copy.keys())
    half = (window_size - 1) / 2
    for data_key in data_key_list:
        center = dateList.index(data_key)  # hoisted: previously computed twice per key
        left = int(max(center - half, 0))
        right = int(min(center + half, len(dateList) - 1))
        potential_neighbor = dateList[left:right+1]
        # only dates that actually have records contribute to the average
        neighbor_data_key = list(set(data_key_list).intersection(set(potential_neighbor)))
        data[data_key] = mob_inf_average_multiple(data_copy, neighbor_data_key)
    return data
#function 1.6
#set the mobility (infection) of one day as zero
def mob_inf_average_null(data, key1, key2):
    """Return a record with value 0 for every id present in both days
    (used to blank out one day's mobility/infection)."""
    first, second = data[key1], data[key2]
    return {zone: 0 for zone in first if zone in second}
#function 1.7
#read the mobility data from "mobility_feature_20200201.json"...
#return: all_mobility:{"20200201":{('123','123'):12345,...},...}
#20200201 to 20210620: 506 days
def read_mobility_data(jcode23):
    """Read the daily OD mobility files into
    {"20200201": {('13101','13102'): flow, ...}, ...} for the 23 wards.

    Inner-zone flows (origin == dest) are zeroed out; the two known missing
    days are filled with the average of their neighbors.
    """
    all_mobility = dict()
    mobilityFilePath = "/data/HSEES/xue/xue_codes/disease_prediction_ml/gml_code/"+\
    "present_model_version10/mobility_20210804"
    for file_name in os.listdir(mobilityFilePath):
        if "20" not in file_name:
            continue
        day = (file_name.split("_")[2]).split(".")[0]  # e.g. mobility_feature_20200201.json
        file_path = mobilityFilePath + '/' + file_name
        with open(file_path) as f:  # context manager instead of bare open/close
            df_file = json.load(f)  # read the mobility file
        day_mobility = dict()
        for key in df_file:
            origin, dest = key.split("_")[0], key.split("_")[1]
            if origin in jcode23 and dest in jcode23:
                if origin == dest:
                    day_mobility[(origin, dest)] = 0.0 #ignore the inner-zone flow
                else:
                    day_mobility[(origin, dest)] = df_file[key]
        all_mobility[day] = day_mobility
    #fill the two known missing days with neighbor averages
    all_mobility["20201128"] = mob_inf_average(all_mobility,"20201127","20201129")
    all_mobility["20210104"] = mob_inf_average(all_mobility, "20210103","20210105")
    return all_mobility
##################2.Text#####################
#function 1.8
#get the average of two days' infection records
def text_average(data, key1, key2):
    """Average two days' per-zone symptom-count records.

    Only zones present on both days are kept, and within a zone only the
    symptoms present on both days.
    """
    first, second = data[key1], data[key2]
    averaged = dict()
    for zone in first:
        if zone not in second:
            continue
        day1, day2 = first[zone], second[zone]
        averaged[zone] = {sym: (day1[sym] + day2[sym]) / 2.0
                          for sym in day1 if sym in day2}
    return averaged
#function 1.9
#get the average of multiple days' text records
def text_average_multiple(data, keyList):
    """Average multiple days' per-zone symptom records.

    Every symptom seen on any of the days is divided by the total number
    of days len(keyList).
    """
    totals = dict()
    num_day = len(keyList)
    for key in keyList:
        for zone_id, sym_counts in data[key].items():
            zone_totals = totals.setdefault(zone_id, dict())
            for sym, count in sym_counts.items():
                zone_totals[sym] = zone_totals.get(sym, 0) + count
    for zone_totals in totals.values():
        for sym in zone_totals:
            zone_totals[sym] = zone_totals[sym] * 1.0 / num_day
    return totals
#function 1.10
#smooth the text data using the neighborhood average
#under a given window size
def text_smooth(data, window_size, dateList):
    """Smooth the text data in place with a centered moving average of
    width `window_size` over calendar days (same scheme as mob_inf_smooth).

    Averages are computed from the original records via the shallow copy,
    then written back into `data`.
    """
    data_copy = copy.copy(data)
    data_key_list = list(data_copy.keys())
    half = (window_size - 1) / 2
    for data_key in data_key_list:
        center = dateList.index(data_key)  # hoisted: previously computed twice per key
        left = int(max(center - half, 0))
        right = int(min(center + half, len(dateList) - 1))
        potential_neighbor = dateList[left:right+1]
        neighbor_data_key = list(set(data_key_list).intersection(set(potential_neighbor)))
        data[data_key] = text_average_multiple(data_copy, neighbor_data_key)
    return data
#function 1.11
#read the number of user points
def read_point_json():
    """Read the daily numbers of mobile-phone user points, merging the two
    files (the second file overrides overlapping days); two missing days
    are copied from the previous day."""
    with open('user_point/mobility_user_point.json') as point1:
        user_point1 = json.load(point1)
    with open('user_point/mobility_user_point_20210812.json') as point2:
        user_point2 = json.load(point2)
    user_point_all = dict()
    user_point_all.update(user_point1)
    user_point_all.update(user_point2)
    user_point_all["20201128"] = user_point_all["20201127"] #data missing
    user_point_all["20210104"] = user_point_all["20210103"] #data missing
    return user_point_all
#function 1.12
#normalize the text search by the number of user points.
def normalize_text_user(all_text, user_point_all):
    """Divide every symptom count by that day's number of user points so
    text signals are comparable across days; days without a user-point
    record are left untouched."""
    for day, day_text in all_text.items():
        if day not in user_point_all:
            continue
        num_user = user_point_all[day]["num_user"]
        all_text[day] = {
            zone: {sym: count * 1.0 / num_user for sym, count in syms.items()}
            for zone, syms in day_text.items()
        }
    return all_text
#function 1.13
#read the text data
#20200201 to 20210620: 506 days
#all_text = {"20200211":{"123":{"code":3,"fever":2,...},...},...}
def read_text_data(jcode23):
    """Read the daily web-search (symptom) files into
    {"20200211": {"13101": {"cough": 3, ...}, ...}, ...} for the 23 wards.

    Counts are scaled by web_search_normalize_ratio; one known missing day
    is filled with the average of its neighbors.
    """
    all_text = dict()
    textFilePath = "/data/HSEES/xue/xue_codes/disease_prediction_ml/gml_code/"+\
    "present_model_version10/text_20210804"
    for file_name in os.listdir(textFilePath):
        if "20" not in file_name:
            continue
        day = (file_name.split("_")[2]).split(".")[0]
        file_path = textFilePath + "/" + file_name
        with open(file_path) as f:  # context manager instead of bare open/close
            df_file = json.load(f)  # read the text file
        new_dict = dict()
        for key in df_file:
            if key in jcode23:
                new_dict[key] = {key1:df_file[key][key1]*1.0*web_search_normalize_ratio for key1 in df_file[key]}
        all_text[day] = new_dict
    all_text["20201030"] = text_average(all_text, "20201029", "20201031") #data missing
    return all_text
#function 1.14
#perform the min-max normalization for the text data.
def min_max_text_data(all_text, jcode23):
    """Min-max normalize each (zone, symptom) text series across all days.

    A constant series (max == min) is mapped to 1; a symptom absent on a
    given day is left untouched for that day.
    """
    text_list = list(['痛み', '頭痛', '咳', '下痢', 'ストレス', '不安', \
                      '腹痛', 'めまい', '吐き気', '嘔吐', '筋肉痛', '動悸', \
                      '副鼻腔炎', '発疹', 'くしゃみ', '倦怠感', '寒気', '脱水', \
                      '中咽頭', '関節痛', '不眠症', '睡眠障害', '鼻漏', '片頭痛', \
                      '多汗症', 'ほてり', '胸痛', '発汗', '無気力', '呼吸困難', \
                      '喘鳴', '目の痛み', '体の痛み', '無嗅覚症', '耳の痛み', \
                      '錯乱', '見当識障害', '胸の圧迫感', '鼻の乾燥', '耳感染症', \
                      '味覚消失', '上気道感染症', '眼感染症', '食欲減少'])
    #per-(zone, symptom) [min, max]; +/-inf sentinels are safe for any
    #magnitude (the old 1000000/0 sentinels break for counts > 1e6)
    region_sym_min_max = dict()
    for key in jcode23:
        region_sym_min_max[key] = dict()
        for sym in text_list:
            region_sym_min_max[key][sym] = [float("inf"), float("-inf")]
    for day in all_text:  #one pass to collect the extremes
        for key in jcode23:
            day_zone = all_text[day][key]
            for sym in text_list:
                if sym in day_zone:
                    count = day_zone[sym]
                    min_max = region_sym_min_max[key][sym]
                    if count < min_max[0]:
                        min_max[0] = count
                    if count > min_max[1]:
                        min_max[1] = count
    for key in jcode23:  #normalize in place
        for sym in text_list:
            min_count, max_count = region_sym_min_max[key][sym]
            span = max_count - min_count
            for day in all_text:
                if sym in all_text[day][key]:
                    if span == 0:
                        all_text[day][key][sym] = 1
                    else:
                        all_text[day][key][sym] = (all_text[day][key][sym]-min_count)*1.0/span
    return all_text
##################3.Infection#####################
#function 1.15
#read the infection data
#20200331 to 20210620: (1+30+31+30+31+31+30+31+30+31)+(31+28+31+30+31+20) = 276 + 171 = 447.
#all_infection = {"20200201":{"123":1,"123":2}}
def read_infection_data(jcode23):
    """Read cumulative infection counts per ward, patch missing/outlier
    dates, and compute day-over-day differences.

    Returns (all_infection_subtraction, all_infection):
      all_infection_subtraction -- daily new cases (cumulative differences)
      all_infection             -- cumulative cases, scaled by
                                   1/infection_normalize_ratio
    """
    all_infection = dict()
    infection_path = "/data/HSEES/xue/xue_codes/disease_prediction_ml/gml_code/"+\
    "present_model_version10/patient_20210725.json"
    with open(infection_path) as f:  # context manager instead of bare open/close
        df_file = json.load(f)  # read the infection file
    for zone_id in df_file:
        for one_day in df_file[zone_id]:
            #dates come as "yyyy/m/d"; zero-pad to "yyyymmdd"
            daySplit = one_day.split("/")
            year, month, day = daySplit[0], daySplit[1], daySplit[2]
            if len(month) == 1:
                month = "0" + month
            if len(day) == 1:
                day = "0" + day
            new_date = year + month + day
            if str(zone_id[0:5]) in jcode23:
                if new_date not in all_infection:
                    all_infection[new_date] = {zone_id[0:5]:df_file[zone_id][one_day]*1.0/infection_normalize_ratio}
                else:
                    all_infection[new_date][zone_id[0:5]] = df_file[zone_id][one_day]*1.0/infection_normalize_ratio
    #missing days: back-fill mid-March with the 20200401 record
    date_list = [str(20200316+i) for i in range(15)]
    for date in date_list:
        all_infection[date] = mob_inf_average(all_infection,'20200401','20200401')
    all_infection['20200514'] = mob_inf_average(all_infection,'20200513','20200515')
    all_infection['20200519'] = mob_inf_average(all_infection,'20200518','20200520')
    all_infection['20200523'] = mob_inf_average(all_infection,'20200522','20200524')
    all_infection['20200530'] = mob_inf_average(all_infection,'20200529','20200601')
    all_infection['20200531'] = mob_inf_average(all_infection,'20200529','20200601')
    all_infection['20201231'] = mob_inf_average(all_infection,'20201230','20210101')
    all_infection['20210611'] = mob_inf_average(all_infection,'20210610','20210612')
    #outliers: replace with neighbor averages
    all_infection['20200331'] = mob_inf_average(all_infection,'20200401','20200401')
    all_infection['20200910'] = mob_inf_average(all_infection,'20200909','20200912')
    all_infection['20200911'] = mob_inf_average(all_infection,'20200909','20200912')
    all_infection['20200511'] = mob_inf_average(all_infection,'20200510','20200512')
    all_infection['20201208'] = mob_inf_average(all_infection,'20201207','20201209')
    all_infection['20210208'] = mob_inf_average(all_infection,'20210207','20210209')
    all_infection['20210214'] = mob_inf_average(all_infection,'20210213','20210215')
    #daily differences of the cumulative counts
    all_infection_subtraction = dict()
    all_infection_subtraction['20200331'] = all_infection['20200331']
    all_keys = list(all_infection.keys())
    all_keys.sort()
    for i in range(len(all_keys)-1):
        record = dict()
        for j in all_infection[all_keys[i+1]]:
            record[j] = all_infection[all_keys[i+1]][j] - all_infection[all_keys[i]][j]
        all_infection_subtraction[all_keys[i+1]] = record
    return all_infection_subtraction, all_infection
##################4.Preprocess#####################
#function 1.16
#ensemble the mobility, text, and infection.
#all_mobility = {"20200201":{('123','123'):12345,...},...}
#all_text = {"20200201":{"123":{"cold":3,"fever":2,...},...},...}
#all_infection = {"20200316":{"123":1,"123":2}}
#all_x_y = {"0":[[mobility_1,text_1, ..., mobility_x_day,text_x_day], [infection_1,...,infection_y_day],\
#[infection_1,...,infection_x_day]],0}
#x_days, y_days: use x_days to predict y_days
def ensemble(all_mobility, all_text, all_infection, x_days, y_days, all_day_list):
    """Build sliding-window samples combining mobility, text, and infection.

    Each sample j is [x_sample, y_sample, x_sample_infection, j] where
    x_sample interleaves [mobility_1, text_1, ..., mobility_x, text_x],
    y_sample holds the next y_days infection records, and
    x_sample_infection holds the x_days input infection records.
    """
    all_x_y = dict()
    num_samples = len(all_day_list) - x_days - y_days + 1
    for j in range(num_samples):
        input_days = all_day_list[j:j + x_days]
        target_days = all_day_list[j + x_days:j + x_days + y_days]
        x_sample = list()
        for day in input_days:
            x_sample.append(all_mobility[day])
            x_sample.append(all_text[day])
        x_sample_infection = [all_infection[day] for day in input_days]
        y_sample = [all_infection[day] for day in target_days]
        all_x_y[str(j)] = [x_sample, y_sample, x_sample_infection, j]
    return all_x_y
#function 1.17
#split the data by train/validate/test = train_ratio/validation_ratio/(1-train_ratio-validation_ratio)
def split_data(all_x_y, train_ratio, validation_ratio):
    """Chronological split into train/validate/test sets; the test
    fraction is 1 - train_ratio - validation_ratio."""
    n = len(all_x_y)
    n_train = round(n * train_ratio)
    n_validate = round(n * validation_ratio)
    n_test = n - n_train - n_validate
    train_key = [all_x_y[str(i)] for i in range(n_train)]
    validate_key = [all_x_y[str(n_train + i)] for i in range(n_validate)]
    test_key = [all_x_y[str(n_train + n_validate + i)] for i in range(n_test)]
    return train_key, validate_key, test_key
##function 1.18
#the second data split method
#split the data by train/validate/test = train_ratio/validation_ratio/(1-train_ratio-validation_ratio)
def split_data_2(all_x_y, train_ratio, validation_ratio):
    """Second split method: within the leading train+validate span, every
    sample whose index satisfies i % 9 == 8 goes to validation and the
    rest to training; the chronological tail is the test set.

    Also returns the chosen train/validate index lists.
    """
    n = len(all_x_y)
    n_train = round(n * train_ratio)
    n_validate = round(n * validation_ratio)
    n_test = n - n_train - n_validate
    train_key, validate_key = list(), list()
    train_list, validate_list = list(), list()
    for i in range(n_train + n_validate):
        if i % 9 == 8:
            validate_key.append(all_x_y[str(i)])
            validate_list.append(i)
        else:
            train_key.append(all_x_y[str(i)])
            train_list.append(i)
    test_key = [all_x_y[str(n_train + n_validate + i)] for i in range(n_test)]
    return train_key, validate_key, test_key, train_list, validate_list
##function 1.19
#the third data split method
#split the data by train/validate/test = train_ratio/validation_ratio/(1-train_ratio-validation_ratio)
def split_data_3(all_x_y, train_ratio, validation_ratio):
    """Third split method: inside the leading train+validate span, every
    second sample of the last 2*n_validate positions goes to validation;
    everything else is training and the chronological tail is the test set.

    Also returns the chosen train/validate index lists.
    """
    n = len(all_x_y)
    n_train = round(n * train_ratio)
    n_validate = round(n * validation_ratio)
    n_test = n - n_train - n_validate
    boundary = n_train + n_validate
    train_key, validate_key = list(), list()
    train_list, validate_list = list(), list()
    for i in range(boundary):
        distance = boundary - i
        if distance % 2 == 0 and distance <= 2 * n_validate:
            validate_key.append(all_x_y[str(i)])
            validate_list.append(i)
        else:
            train_key.append(all_x_y[str(i)])
            train_list.append(i)
    test_key = [all_x_y[str(boundary + i)] for i in range(n_test)]
    return train_key, validate_key, test_key, train_list, validate_list
##function 1.20
#find the mobility data starting from the day, which is x_days before the start_date
#start_date = "20200331", x_days = 7
def sort_date(all_mobility, start_date, x_days):
    """Return the sorted date keys beginning x_days before start_date."""
    dates = sorted(all_mobility.keys())
    start_idx = dates.index(start_date)
    return dates[start_idx - x_days:]
#function 1.21
#find the mobility data starting from the day, which is x_days before the start_date,
#ending at the day, which is y_days after the end_date
#start_date = "20200331", x_days = 7
def sort_date_2(all_mobility, start_date, x_days, end_date, y_days):
    """Return the sorted date keys from x_days before start_date through
    y_days - 1 after end_date (slice end is exclusive)."""
    dates = sorted(all_mobility.keys())
    start_idx = dates.index(start_date)
    end_idx = dates.index(end_date)
    return dates[start_idx - x_days:end_idx + y_days]
#function 1.22
#get the mappings from zone id to id, text id to id.
#get zone_text_to_idx
def get_zone_text_to_idx(all_infection):
    """Build zone-id -> index and symptom -> index mappings.

    Zones are taken from the 20200401 infection record (sorted); symptoms
    are the fixed 8-symptom vocabulary used by the model.
    """
    text_list = ['痛み', '頭痛', '咳', '下痢', 'ストレス', '不安', \
                 '腹痛', 'めまい']
    zone_list = sorted(set(all_infection["20200401"].keys()))
    zone_dict = {str(zone): idx for idx, zone in enumerate(zone_list)}
    text_dict = {str(sym): idx for idx, sym in enumerate(text_list)}
    return zone_dict, text_dict
#function 1.23
#change the data format to matrix
#zoneid_to_idx = {"13101":0, "13102":1, ..., "13102":22}
#sym_to_idx = {"cough":0}
#mobility: {('13101', '13101'): 709973, ...}
#text: {'13101': {'痛み': 51,...},...} text
#infection: {'13101': 50, '13102': 137, '13103': 401,...}
#data_type = {"mobility", "text", "infection"}
def to_matrix(zoneid_to_idx, sym_to_idx, input_data, data_type):
    """Convert one day's record dict into a numpy array.

    zoneid_to_idx: {"13101": 0, ...}; sym_to_idx: {"cough": 0, ...}.
    data_type selects the layout:
      "mobility"  -- (n_zone, n_zone) OD matrix from {(orig, dest): flow}
      "text"      -- (n_zone, n_text) matrix from {zone: {sym: count}}
      "infection" -- (n_zone,) vector from {zone: count}
    Raises ValueError for any other data_type (previously this path died
    with an UnboundLocalError on `result`).
    """
    n_zone, n_text = len(zoneid_to_idx), len(sym_to_idx)
    if data_type == "mobility":
        result = np.zeros((n_zone, n_zone))
        for key in input_data:
            from_idx, to_idx = zoneid_to_idx[key[0]], zoneid_to_idx[key[1]]
            result[from_idx][to_idx] += input_data[key]
    elif data_type == "text":
        result = np.zeros((n_zone, n_text))
        for key1 in input_data:
            for key2 in input_data[key1]:
                # plain dict membership instead of the O(n) `in list(keys())` scans
                if key1 in zoneid_to_idx and key2 in sym_to_idx:
                    result[zoneid_to_idx[key1]][sym_to_idx[key2]] += input_data[key1][key2]
    elif data_type == "infection":
        result = np.zeros(n_zone)
        for key in input_data:
            result[zoneid_to_idx[key]] += input_data[key]
    else:
        raise ValueError("unknown data_type: %s" % data_type)
    return result
#function 1.24
#change the data to the matrix format
def change_to_matrix(data, zoneid_to_idx, sym_to_idx):
    """Convert every sample from dict records to numpy matrices.

    Each output sample is [mobility/text matrices interleaved,
    infection_y vectors, infection_x vectors, day_order].
    """
    data_result = list()
    for sample in data:
        mobility_text = sample[0]
        y_infections = sample[1]
        x_infections = sample[2]  # the x_days infection data
        day_order = sample[3]  # the order of the day
        combine1, combine2, combine3 = list(), list(), list()
        n_days = round(len(mobility_text) * 1.0 / 2)
        for day_idx in range(n_days):
            combine1.append(to_matrix(zoneid_to_idx, sym_to_idx, mobility_text[2*day_idx], "mobility"))
            combine1.append(to_matrix(zoneid_to_idx, sym_to_idx, mobility_text[2*day_idx+1], "text"))
            combine3.append(to_matrix(zoneid_to_idx, sym_to_idx, x_infections[day_idx], "infection"))
        for infection in y_infections:
            combine2.append(to_matrix(zoneid_to_idx, sym_to_idx, infection, "infection"))
        data_result.append([combine1,combine2,combine3,day_order]) #mobility/text; infection_y; infection_x; day_order
    return data_result
##################5.learn#####################
#function 1.25
def visual_loss(e_losses, vali_loss, test_loss):
    """Plot the train/validate/test loss curves over the epochs."""
    plt.figure(figsize=(4,3), dpi=300)
    epochs = range(len(e_losses))
    curves = ((e_losses, "train"), (vali_loss, "validate"), (test_loss, "test"))
    for curve, label in curves:
        plt.plot(epochs, copy.copy(curve), linewidth=1, label=label)
    plt.legend()
    plt.title('Loss decline on entire training/validation/testing data')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    #plt.savefig('final_f6.png',bbox_inches = 'tight')
    plt.show()
#function 1.26
def visual_loss_train(e_losses):
    """Plot the training-loss curve over the epochs."""
    plt.figure(figsize=(4,3), dpi=300)
    epochs = range(len(e_losses))
    plt.plot(epochs, copy.copy(e_losses), linewidth=1, label="train")
    plt.legend()
    plt.title('Loss decline on entire training data')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    #plt.savefig('final_f6.png',bbox_inches = 'tight')
    plt.show()
#function 2.1
#normalize each column of the input mobility matrix as one
def normalize_column_one(input_matrix):
    """Scale each column of the mobility matrix (in place) so it sums to
    one, i.e. make every zone's in-flow a distribution."""
    column_sum = np.sum(input_matrix, axis=0)
    n_rows, n_cols = len(input_matrix), len(input_matrix[0])
    for col in range(n_cols):
        denom = column_sum[col]
        for row in range(n_rows):
            input_matrix[row][col] = input_matrix[row][col] * 1.0 / denom
    return input_matrix
#function 2.2
#evalute the trained_model on validation or testing data.
def validate_test_process(trained_model, vali_test_data):
    """Evaluate trained_model on validation or testing samples.

    Returns (MSE loss, predictions y_hat, ground truth y_real).
    NOTE(review): this function is redefined verbatim as "function 3.3"
    later in the file; that later definition is the one in effect.
    """
    criterion = nn.MSELoss()
    # element [1] of each sample holds the y_days infection targets
    vali_test_y = [vali_test_data[i][1] for i in range(len(vali_test_data))]
    y_real = torch.tensor(vali_test_y)
    vali_test_x = [vali_test_data[i] for i in range(len(vali_test_data))]
    vali_test_x = convertAdj(vali_test_x)  # column-normalize the mobility matrices
    y_hat = trained_model.run_specGCN_lstm(vali_test_x)
    loss = criterion(y_hat.float(), y_real.float()) ###Calculate the loss
    return loss, y_hat, y_real
#function 2.3
#convert the mobility matrix in x_batch in a following way
#normalize the flow between zones so that the in-flow of each zone is 1.
def convertAdj(x_batch):
    """Column-normalize the mobility matrices of every sample in x_batch
    so each zone's total in-flow equals 1.

    x_batch[i][0] interleaves [mobility_0, text_0, mobility_1, text_1, ...];
    only the even (mobility) entries are touched.

    NOTE(review): copy.copy is a shallow copy, so x_batch_new shares the
    inner sample lists with x_batch, and normalize_column_one mutates the
    matrices in place -- the caller's x_batch is therefore modified too.
    This appears harmless only because column normalization is idempotent;
    confirm before reusing the un-normalized data elsewhere.
    """
    #x_batch:(n_batch, 0/1, 2*i+1)
    x_batch_new = copy.copy(x_batch)
    n_batch = len(x_batch)
    days = round(len(x_batch[0][0])/2)
    for i in range(n_batch):
        for j in range(days):
            mobility_matrix = x_batch[i][0][2*j]
            x_batch_new[i][0][2*j] = normalize_column_one(mobility_matrix) #20210818
    return x_batch_new
#function 2.4
#a training epoch
def train_epoch_option(model, opt, criterion, trainX_c, trainY_c, batch_size):
    """Run one training epoch over (trainX_c, trainY_c) in mini-batches.

    The mobility matrices of each batch are column-normalized via
    convertAdj before the forward pass. Returns (mean batch loss, model).
    """
    model.train()
    losses = []
    total_batches = int(len(trainX_c)/batch_size)  # hoisted out of the loop
    batch_num = 0
    for beg_i in range(0, len(trainX_c), batch_size):
        batch_num += 1
        if batch_num % 16 == 0:
            print ("batch_num: ", batch_num, "total batch number: ", total_batches)
        x_batch = trainX_c[beg_i:beg_i+batch_size]
        y_batch = torch.tensor(trainY_c[beg_i:beg_i+batch_size])
        opt.zero_grad()
        x_batch = convertAdj(x_batch) #conduct the column normalization
        y_hat = model.run_specGCN_lstm(x_batch) ###Attention
        loss = criterion(y_hat.float(), y_batch.float()) #MSE loss
        loss.backward()
        opt.step()
        losses.append(loss.item())  # .item() instead of deprecated loss.data.numpy()
    return sum(losses)/float(len(losses)), model
#function 2.5
#multiple training epoch
def train_process(train_data, lr, num_epochs, net, criterion, bs, vali_data, test_data):
    """Train `net` for num_epochs epochs with Adam, reshuffling the samples
    every epoch.

    Prints de-normalized train/validate/test losses each epoch and plots
    the loss curves every 10 epochs. Returns (train losses per epoch, net).
    """
    opt = optim.Adam(net.parameters(), lr, betas = (0.9,0.999), weight_decay=0)
    train_y = [train_data[i][1] for i in range(len(train_data))]
    e_losses = list()
    e_losses_vali = list()
    e_losses_test = list()
    time00 = time.time()
    for e in range(num_epochs):
        time1 = time.time()
        print ("current epoch: ",e, "total epoch: ", num_epochs)
        number_list = list(range(len(train_data)))
        # NOTE(review): the `random` keyword of random.shuffle was removed in
        # Python 3.11 -- confirm the runtime version before upgrading.
        random.shuffle(number_list, random = r)
        trainX_sample = [train_data[number_list[j]] for j in range(len(number_list))]
        trainY_sample = [train_y[number_list[j]] for j in range(len(number_list))]
        loss, net = train_epoch_option(net, opt, criterion, trainX_sample, trainY_sample, bs)
        # losses are multiplied by infection_normalize_ratio^2 to undo the
        # /100 scaling of the infection counts (MSE scales quadratically)
        print ("train loss", loss*infection_normalize_ratio*infection_normalize_ratio)
        e_losses.append(loss*infection_normalize_ratio*infection_normalize_ratio)
        loss_vali, y_hat_vali, y_real_vali = validate_test_process(net, vali_data)
        loss_test, y_hat_test, y_real_test = validate_test_process(net, test_data)
        e_losses_vali.append(float(loss_vali)*infection_normalize_ratio*infection_normalize_ratio)
        e_losses_test.append(float(loss_test)*infection_normalize_ratio*infection_normalize_ratio)
        print ("validate loss", float(loss_vali)*infection_normalize_ratio*infection_normalize_ratio)
        print ("test loss", float(loss_test)*infection_normalize_ratio*infection_normalize_ratio)
        if e>=2 and (e+1)%10 ==0:
            visual_loss(e_losses, e_losses_vali, e_losses_test)
            visual_loss_train(e_losses)
        time2 = time.time()
        print ("running time for this epoch:", time2 - time1)
        time01 = time.time()
        print ("---------------------------------------------------------------")
        print ("---------------------------------------------------------------")
        #print ("total running time until now:", time01 - time00)
        #print ("------------------------------------------------")
        #print("specGCN_weight", net.specGCN.layer1.W)
        #print("specGCN_weight_grad", net.specGCN.layer1.W.grad)
        #print ("------------------------------------------------")
        #print("memory decay matrix", net.v)
        #print("memory decay matrix grad", net.v.grad)
        #print ("------------------------------------------------")
        #print ("lstm weight", net.lstm.all_weights[0][0])
        #print ("lstm weight grad", net.lstm.all_weights[0][0].grad)
        #print ("------------------------------------------------")
        #print ("fc1.weight", net.fc1.weight)
        #print ("fc1 weight grd", net.fc1.weight.grad)
        #print ("---------------------------------------------------------------")
        #print ("---------------------------------------------------------------")
    return e_losses, net
#function 3.1
def read_data():
    """Load, smooth, normalize, and matrix-ify all data sources.

    Returns the train/validate/test matrix samples, the raw mobility and
    infection dicts, the original (dict-format) splits, and the indices
    chosen for training/validation by split_data_3.
    """
    jcode23 = list(read_tokyo_23()["JCODE"]) #1.1 get the tokyo 23 zone shapefile
    all_mobility = read_mobility_data(jcode23) #1.2 read the mobility data
    all_text = read_text_data(jcode23) #1.3 read the text data
    all_infection, all_infection_cum = read_infection_data(jcode23) #1.4 read the infection data
    #smooth the data using 7-days average
    window_size = WINDOW_SIZE #20210818
    dateList = generate_dateList() #20210818
    all_mobility = mob_inf_smooth(all_mobility, window_size, dateList) #20210818
    all_infection = mob_inf_smooth(all_infection, window_size, dateList) #20210818
    #text pipeline: per-user normalization, smoothing, then min-max scaling
    point_json = read_point_json() #20210821
    all_text = normalize_text_user(all_text, point_json) #20210821
    all_text = text_smooth(all_text, window_size, dateList) #20210818
    all_text = min_max_text_data(all_text,jcode23) #20210820
    x_days, y_days = X_day, Y_day
    mobility_date_cut = sort_date_2(all_mobility, START_DATE, x_days, END_DATE, y_days)
    all_x_y = ensemble(all_mobility, all_text, all_infection, x_days, y_days, mobility_date_cut)
    train_original, validate_original, test_original, train_list, validation_list =\
    split_data_3(all_x_y,train_ratio,validate_ratio)
    zone_dict, text_dict = get_zone_text_to_idx(all_infection) #get zone_idx, text_idx
    train_x_y = change_to_matrix(train_original, zone_dict, text_dict) #get train
    print ("train_x_y_shape",len(train_x_y),"train_x_y_shape[0]",len(train_x_y[0]))
    validate_x_y = change_to_matrix(validate_original, zone_dict, text_dict) #get validate
    test_x_y = change_to_matrix(test_original, zone_dict, text_dict) #get test
    print (len(train_x_y)) #300
    print (len(train_x_y[0][0])) #14
    print (np.shape(train_x_y[0][0][0])) #(23,23)
    print (np.shape(train_x_y[0][0][1])) #(23,43)
    #print ("---------------------------------finish data reading and preprocessing------------------------------------")
    return train_x_y, validate_x_y, test_x_y, all_mobility, all_infection, train_original, validate_original, test_original, train_list, validation_list
#function 3.2
#train the model
def model_train(train_x_y, vali_data, test_data):
    """Build a SpecGCN_LSTM from the module-level hyperparameters and train
    it on train_x_y, reporting validation/testing losses every epoch.

    Returns (per-epoch training losses, trained model).
    """
    #3.2.1 define the model
    # input_dim_1 is the text-feature width (columns of the first text matrix);
    # N is the number of zones (rows of that matrix)
    input_dim_1, hidden_dim_1, out_dim_1, hidden_dim_2 = len(train_x_y[0][0][1][1]),\
    HIDDEN_DIM_1, OUT_DIM_1, HIDDEN_DIM_2
    dropout_1, alpha_1, N = DROPOUT, ALPHA, len(train_x_y[0][0][1])
    G_L_Model = SpecGCN_LSTM(X_day, Y_day, input_dim_1, hidden_dim_1, out_dim_1, hidden_dim_2, dropout_1,N) ###Attention
    #3.2.2 train the model
    num_epochs, batch_size, learning_rate = NUM_EPOCHS, BATCH_SIZE, LEARNING_RATE #model train
    criterion = nn.MSELoss()
    e_losses, trained_model = train_process(train_x_y, learning_rate, num_epochs, G_L_Model, criterion, batch_size,\
    vali_data, test_data)
    return e_losses, trained_model
#function 3.3
#evaluate the error on validation (or testing) data.
def validate_test_process(trained_model, vali_test_data):
    """Evaluate trained_model on validation or testing samples.

    Returns (MSE loss, predictions y_hat, ground truth y_real).
    NOTE(review): this redefines "function 2.2" verbatim; this later
    definition shadows the earlier one. One of the two should be removed.
    """
    criterion = nn.MSELoss()
    # element [1] of each sample holds the y_days infection targets
    vali_test_y = [vali_test_data[i][1] for i in range(len(vali_test_data))]
    y_real = torch.tensor(vali_test_y)
    vali_test_x = [vali_test_data[i] for i in range(len(vali_test_data))]
    vali_test_x = convertAdj(vali_test_x)  # column-normalize the mobility matrices
    y_hat = trained_model.run_specGCN_lstm(vali_test_x) ###Attention
    loss = criterion(y_hat.float(), y_real.float())
    return loss, y_hat, y_real
#4.1
#read the data
# Top-level script: load and preprocess every data source (runs at import time).
train_x_y, validate_x_y, test_x_y, all_mobility, all_infection, \
train_original, validate_original, test_original, train_list, validation_list =\
read_data()
#train_x_y, validate_x_y, test_x_y = normalize(train_x_y, validate_x_y, test_x_y)
#train_x_y = train_x_y[0:30]
print (len(train_x_y))
print ("---------------------------------finish data preparation------------------------------------")
# Console output pasted from a notebook run (not Python code -- commented out so the file parses):
# train_x_y_shape 210 train_x_y_shape[0] 4 210 42 (23, 23) (23, 8) 210 ---------------------------------finish data preparation------------------------------------
#4.2
#train the model
# Top-level script: train the SpecGCN+LSTM model on the prepared splits.
e_losses, trained_model = model_train(train_x_y, validate_x_y, test_x_y)
print ("---------------------------finish model training-------------------------")
# [Jupyter cell output, commented out so the file stays valid Python]
# Per-epoch training log for epochs 0-99 (26 batches/epoch, ~14-16 s each).
# Losses fall steeply early on and then level off; representative points:
#   epoch  0: train 401.14  validate 365.32  test 432.29
#   epoch 10: train 251.84  validate 257.15  test 150.60
#   epoch 25: train 154.29  validate  63.50  test  53.95
#   epoch 50: train  98.11  validate  63.94  test  39.80
#   epoch 75: train  79.67  validate  79.37  test  37.30
#   epoch 99: train  76.60  validate  82.27  test  34.51
# (validation loss bottoms out around epochs 25-35 and then drifts upward,
# while test loss keeps improving slowly through epoch 99)
# ---------------------------finish model training-------------------------
#4.3 sanity-check the sizes of the three data splits before evaluation
print (len(train_x_y))
print (len(validate_x_y))
print (len(test_x_y))
#4.3.1 model validation
# presumably returns (aggregate loss, per-sample predictions, per-sample
# ground truth) — TODO confirm against validate_test_process's definition
validation_result, validate_hat, validate_real = validate_test_process(trained_model, validate_x_y)
print ("---------------------------------finish model validation------------------------------------")
print (len(validate_hat))
print (len(validate_real))
#4.3.2 model testing
#4.4. model test — same procedure on the held-out test split
test_result, test_hat, test_real = validate_test_process(trained_model, test_x_y)
print ("---------------------------------finish model testing------------------------------------")
print (len(test_real))
print (len(test_hat))
210 30 60 ---------------------------------finish model validation------------------------------------ 30 30 ---------------------------------finish model testing------------------------------------ 60 60
#5.1 RMSE, MAPE, MAE, RMSLE
def RMSELoss(yhat, y):
    """Root-mean-square error between prediction and target tensors, as a Python float."""
    squared_err = (yhat - y) ** 2
    return float(squared_err.mean().sqrt())
def MAPELoss(yhat, y):
    """Mean absolute percentage error as a fraction (not multiplied by 100)."""
    rel_err = (yhat - y).abs() / y
    return float(rel_err.mean())
def MAELoss(yhat, y):
    """Mean absolute error between prediction and target tensors, as a Python float.

    Fix: the original computed torch.div(torch.abs(yhat - y), 1) — a no-op
    division by 1 (apparently a leftover from copying MAPELoss). The value is
    unchanged; the pointless division is removed.
    """
    return float(torch.mean(torch.abs(yhat - y)))
def RMSLELoss(yhat, y):
    """Root-mean-square logarithmic error; the +1 shift keeps log finite at zero counts."""
    log_diff = torch.log(yhat + 1) - torch.log(y + 1)
    return float(torch.sqrt((log_diff ** 2).mean()))
#compute RMSE: one scalar error per validation / test sample
rmse_validate = [RMSELoss(validate_hat[i], validate_real[i]) for i in range(len(validate_x_y))]
rmse_test = [RMSELoss(test_hat[i], test_real[i]) for i in range(len(test_x_y))]
print ("rmse_validate mean", np.mean(rmse_validate))
print ("rmse_test mean", np.mean(rmse_test))
#compute MAE: one scalar error per validation / test sample
mae_validate = [MAELoss(validate_hat[i], validate_real[i]) for i in range(len(validate_x_y))]
mae_test = [MAELoss(test_hat[i], test_real[i]) for i in range(len(test_x_y))]
print ("mae_validate mean", np.mean(mae_validate))
print ("mae_test mean", np.mean(mae_test))
#show RMSE and MAE together, rescaled back to raw daily-case counts
#(the model works on infections divided by infection_normalize_ratio)
mae_validate = np.array(mae_validate) * infection_normalize_ratio
rmse_validate = np.array(rmse_validate) * infection_normalize_ratio
mae_test = np.array(mae_test) * infection_normalize_ratio
rmse_test = np.array(rmse_test) * infection_normalize_ratio
print ("-----------------------------------------")
print ("mae_validate mean", round(np.mean(mae_validate),3), " rmse_validate mean", round(np.mean(rmse_validate),3))
print ("mae_test mean", round(np.mean(mae_test),3), " rmse_test mean", round(np.mean(rmse_test),3))
print ("-----------------------------------------")
rmse_validate mean 0.07805463213108704 rmse_test mean 0.05572519766550367 mae_validate mean 0.05479820137922272 mae_test mean 0.04275543221436755 ----------------------------------------- mae_validate mean 5.48 rmse_validate mean 7.805 mae_test mean 4.276 rmse_test mean 5.573 -----------------------------------------
# inspect the model's learned parameter vector `v` (one weight per zone —
# presumably the attention/recovery weights; confirm in the SpecGCN_LSTM class)
print (trained_model.v)
Parameter containing:
tensor([0.1604, 0.0831, 0.0438, 0.0032, 0.1429, 0.0678, 0.0641, 0.0308, 0.0450,
0.0561, 0.0012, 0.0002, 0.0469, 0.0372, 0.0117, 0.0457, 0.0575, 0.0785,
0.0084, 0.0024, 0.0004, 0.0121, 0.0020], requires_grad=True)
# spot-check the first validation sample on the last forecast day (index Y_day-1):
# predicted per-zone values and their citywide sum, then the observed counterparts
# (values are in normalized units — multiply by infection_normalize_ratio for raw cases)
print(validate_hat[0][Y_day-1])
print(torch.sum(validate_hat[0][Y_day-1]))
print(validate_real[0][Y_day-1])
print(torch.sum(validate_real[0][Y_day-1]))
tensor([0.0422, 0.1507, 0.3527, 0.5436, 0.1893, 0.2907, 0.2042, 0.4066, 0.3169,
0.3683, 0.7374, 0.8928, 0.3791, 0.4451, 0.5284, 0.2628, 0.3218, 0.2103,
0.4943, 0.5779, 0.6635, 0.3360, 0.5802], grad_fn=<SelectBackward>)
tensor(9.2945, grad_fn=<SumBackward0>)
tensor([0.0329, 0.0671, 0.1414, 0.3457, 0.0914, 0.1643, 0.1729, 0.2100, 0.1914,
0.1786, 0.4986, 0.4871, 0.1257, 0.2300, 0.3586, 0.1857, 0.1514, 0.1329,
0.3371, 0.3314, 0.3386, 0.4700, 0.3471], dtype=torch.float64)
tensor(5.5900, dtype=torch.float64)
# Plot per-day RMSE and MAE over the validation period.
# Fix: the original assigned `l1`/`l2` (plot handles) and `my_y_ticks` without
# ever using them; the dead assignments are removed. NOTE(review): my_y_ticks
# = np.arange(0, 2100, 500) was likely meant to be passed to plt.yticks() as in
# the combined plot below — confirm before restoring it, since adding it would
# change the figure.
x = range(len(rmse_validate))
plt.figure(figsize=(8,2),dpi=300)
plt.plot(x, np.array(rmse_validate), 'ro-',linewidth=0.8, markersize=1.2, label='RMSE')
plt.plot(x, np.array(mae_validate), 'go-',linewidth=0.8, markersize=1.2, label='MAE')
plt.xlabel('Date from the first day of validation',fontsize=12)
plt.ylabel("RMSE/MAE daily new cases",fontsize=10)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.legend()
plt.grid()
plt.show()
# Plot per-day RMSE and MAE over the test period.
# Fix: the original assigned `l1`/`l2` (plot handles) and `my_y_ticks` without
# ever using them; the dead assignments are removed. NOTE(review): my_y_ticks
# was likely meant to be passed to plt.yticks() — confirm before restoring it,
# since adding it would change the figure.
x = range(len(mae_test))
plt.figure(figsize=(8,2),dpi=300)
plt.plot(x, np.array(rmse_test), 'ro-',linewidth=0.8, markersize=1.2, label='RMSE')
plt.plot(x, np.array(mae_test), 'go-',linewidth=0.5, markersize=1.2, label='MAE')
plt.xlabel('Date from the first day of test',fontsize=12)
plt.ylabel("RMSE/MAE Daily new cases",fontsize=10)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.legend()
plt.grid()
plt.show()
from scipy import stats
# Pearson correlation between predicted and observed citywide totals (sum over
# all zones) on the last forecast day, for each of the three splits.
#validate
y_days = Y_day
last_day = y_days - 1
validate_hat_sum = [float(torch.sum(pred[last_day])) for pred in validate_hat]
validate_real_sum = [float(torch.sum(obs[last_day])) for obs in validate_real]
print ("the correlation between validation: ", stats.pearsonr(validate_hat_sum, validate_real_sum)[0])
#test
test_hat_sum = [float(torch.sum(pred[last_day])) for pred in test_hat]
test_real_sum = [float(torch.sum(obs[last_day])) for obs in test_real]
print ("the correlation between test: ", stats.pearsonr(test_hat_sum, test_real_sum)[0])
#train — note: uses day index 0, not the last forecast day
train_result, train_hat, train_real = validate_test_process(trained_model, train_x_y)
train_hat_sum = [float(torch.sum(pred[0])) for pred in train_hat]
train_real_sum = [float(torch.sum(obs[0])) for obs in train_real]
print ("the correlation between train: ", stats.pearsonr(train_hat_sum, train_real_sum)[0])
the correlation between validation: 0.9059964820276863 the correlation between test: 0.5424605127424215 the correlation between train: 0.9705166126805527
# Citywide daily totals on the last forecast day: training ground truth, plus
# predicted vs. observed curves for the validation and test periods.
y1List = [np.sum(list(sample[1][Y_day-1].values())) for sample in train_original[1:]]
y2List = [np.sum(list(sample[1][Y_day-1].values())) for sample in validate_original]
y2List_hat = [float(torch.sum(pred[Y_day-1])) for pred in validate_hat]
y3List = [np.sum(list(sample[1][Y_day-1].values())) for sample in test_original]
y3List_hat = [float(torch.sum(pred[Y_day-1])) for pred in test_hat]
# x-axes: train/validation day indices come from precomputed lists; the test
# axis continues right after the train + validation segments.
x1 = train_list
x2 = validation_list
test_offset = len(y1List) + len(y2List)
x3 = np.array([test_offset + j for j in range(len(y3List))])
plt.figure(figsize=(8,2),dpi=300)
plt.plot(x1[0: len(y1List)], np.array(y1List)*infection_normalize_ratio, 'ro-',linewidth=0.8, markersize=2.0, label='train')
plt.plot(x2, np.array(y2List)*infection_normalize_ratio, 'go-',linewidth=0.8, markersize=2.0, label='validate')
plt.plot(x2, np.array(y2List_hat)*infection_normalize_ratio, 'g-',linewidth=2, markersize=0.1, label='validate_predict')
plt.plot(x3, np.array(y3List)*infection_normalize_ratio, 'bo-',linewidth=0.8, markersize=2, label='test')
plt.plot(x3, np.array(y3List_hat)*infection_normalize_ratio, 'b-',linewidth=2, markersize=0.1, label='test_predict')
plt.ylabel("Daily infection cases",fontsize=10)
my_y_ticks = np.arange(0,2100, 500)
my_x_ticks = [60 * i for i in range(6)]   # 0, 60, ..., 300
plt.xticks(my_x_ticks)
plt.yticks(my_y_ticks)
plt.xticks(fontsize=8)
plt.yticks(fontsize=12)
plt.title("SpectralGCN")
plt.legend()
plt.grid()
plt.show()
def getPredictionPlot(k):
    """Plot observed vs. predicted daily infections for zone k over the test period.

    Reads the module-level test_real / test_hat tensors on the last forecast
    day (index y_days-1) and rescales by infection_normalize_ratio.
    """
    days = list(range(len(test_real)))
    observed = [sample[y_days-1][k] for sample in test_real]
    predicted = [sample[y_days-1][k] for sample in test_hat]
    plt.figure(figsize=(4,2.5), dpi=300)
    plt.plot(days, np.array(observed)*infection_normalize_ratio, 'ro-',linewidth=0.8, markersize=2.0, label='real',alpha = 0.8)
    plt.plot(days, np.array(predicted)*infection_normalize_ratio, 'o-',color='black',linewidth=0.8, markersize=2.0, alpha = 0.8, label='predict')
    plt.xticks([10 * i for i in range(7)])      # 0, 10, ..., 60
    plt.yticks(np.arange(0,100,40))
    plt.xticks(fontsize = 14)
    plt.yticks(fontsize = 14)
    plt.title("Real and predict daily infection for region "+str(k))
    plt.legend()
    plt.grid()
    plt.show()
# render one real-vs-predicted plot for each of the 23 special wards of Tokyo
for i in range(23):
    getPredictionPlot(i)